Experiments on Snake Activation#

To induce a periodic extrapolation bias in neural networks, Ziyin et al. (2020) proposed a simple activation function called the “Snake” activation, of the form \(x + \frac{1}{a}\sin^2(ax)\), where \(a\) can be treated as a constant hyperparameter or a learned parameter.
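As a concrete reference, the activation is only a few lines of PyTorch. Note that `snake` is our own helper name here, not a function from the paper's code:

```python
import torch

def snake(x, a=1.0):
    """Snake activation: x + (1/a) * sin^2(a * x)."""
    return x + torch.sin(a * x) ** 2 / a
```

Since \(\sin^2(ax)/a \ge 0\) for \(a > 0\), the activation never dips below the identity, which preserves the monotone trend while adding a periodic ripple.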

We’ve experimented with the Snake activation…

Extrapolation Experiment#

We generated synthetic data from the sin(x) function, which we aim to learn. The blue points are the training set, which we use to train our model. The orange points are the test set, which we use to check whether our model generalizes and actually learned the sine function.

Our inputs are the x-values (horizontal axis) and our targets are y = sin(x), which we will train our model to predict given x.

../_images/13dbf81813a1c3690e4a7c805ac0ac5a7830a172e84f28c1971190305c2886c8.png

We show an animation below of the neural network parameters and the evolution of how it fits the data over its training (epochs).

In these experiments, we used

  • Xavier Normal initialization

  • Two hidden layers

  • 512 neurons per hidden layer

  • learning rate = 0.0001

  • a = 30 for the activation functions
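The `MLP` class used in the cells below is not shown in this excerpt. A minimal sketch consistent with the settings above (two hidden layers of 512 units, Xavier Normal initialization, Snake with a = 30) might look like the following; `SnakeMLP` is a hypothetical name, not the actual class definition:

```python
import torch
import torch.nn as nn

class SnakeMLP(nn.Module):
    """Two hidden layers with Snake activation applied after each."""
    def __init__(self, hidden=512, a=30.0):
        super().__init__()
        self.a = a
        self.fc1 = nn.Linear(1, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, 1)
        for lin in (self.fc1, self.fc2, self.out):
            nn.init.xavier_normal_(lin.weight)  # Xavier Normal initialization

    def snake(self, x):
        # x + (1/a) * sin^2(a * x)
        return x + torch.sin(self.a * x) ** 2 / self.a

    def forward(self, x):
        x = self.snake(self.fc1(x))
        x = self.snake(self.fc2(x))
        return self.out(x)
```

With inputs of shape `(batch, 1)` (the x-values) this predicts `(batch, 1)` targets, matching the scalar-to-scalar regression setup above.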

Interpolation Experiment#

Now, what if we make this an interpolation problem instead of an extrapolation problem? In other words, what if we swap the train and test sets? Will the model be able to infer the sine wave in between the test data? We show below the sine wave colored by train (blue) and test (orange).

Note that we use the term “interpolate” loosely here. This is technically still an extrapolation problem since the distribution of the test set is not within the support of the distribution of the training set.

../_images/905b21915514330024b73b15a3e1b52b529d0c21fbc2782d7b05c4493ceb64d3.png

Extrapolation investigation#

def gen_data2(L=1000, prop_train=0.5, start=-30*torch.pi, end=30*torch.pi, reverse=False):
    """Generate (x, sin(x)) pairs split into two training bands separated by test gaps."""
    x = torch.linspace(start, end, L)
    y = torch.sin(x) # + torch.sin(x/3) + torch.sin(x*3)

    cnt_train = int(L * prop_train)

    # Two training bands, each cnt_train // 2 points wide, each preceded by an L // 8 gap
    band1 = (L//8, L//8 + cnt_train//2)
    band2 = (band1[1] + L//8, band1[1] + L//8 + cnt_train//2)
    train_inds = list(np.arange(*band1)) + list(np.arange(*band2))

    # Everything outside the training bands becomes the test set
    train_set = set(train_inds)
    test_inds = [i for i in np.arange(L) if i not in train_set]

    x_train, y_train = x[train_inds], y[train_inds]
    x_test, y_test = x[test_inds], y[test_inds]

    if reverse: # swap the roles of train and test
        return x_test, y_test, x_train, y_train
    return x_train, y_train, x_test, y_test
x_train, y_train, x_test, y_test = gen_data2(L=2000, prop_train=0.5, start=-50, end=50, reverse=False)

fig, axs = plt.subplots(1, 1, figsize=(12, 2))
plt.scatter(x_train, y_train, label='train', s=2)
plt.scatter(x_test, y_test, label='test', s=2)
plt.legend()
plt.tight_layout()
axs.spines[['right', 'top']].set_visible(False)
../_images/f972d16751305fd8b6ad95b7771403dc0fabd482f5d18b1da67ca165266ff7e7.png
# for animation
fig, axs = plt.subplots(1, 1, figsize=(10, 2)) # dpi=200)
../_images/a3c6bf0178b94218c3407006c24376295e3b217d3130647e514cf9388774538e.png
snake_a = 30

model = MLP(count_nodes, snake_a=snake_a)
model = model.to(device)
optim = torch.optim.Adam(model.parameters(),lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size= 1000)

anim = animation.FuncAnimation(fig,
                               animate,
                               frames=frames,
                               interval=1)
HTML(anim.to_jshtml(fps=10))

Fitting on a Decaying Signal#

../_images/222172357c0b2a219c9f9317f4cdfb73a785477b19ce835c939d5b1c4200ed0c.png

Different Coefficients \(\frac{1}{b}\sin(ax)\)#
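The code for this variant is not shown in the excerpt. As a hedged sketch, decoupling the amplitude coefficient \(b\) from the frequency coefficient \(a\) (in the original Snake, the single parameter \(a\) controls both) could look like the following; `snake_ab` is our own name for this variant:

```python
import torch

def snake_ab(x, a=1.0, b=1.0):
    """Snake variant with decoupled coefficients: x + (1/b) * sin(a * x).

    a sets the frequency of the periodic term, b sets its amplitude.
    """
    return x + torch.sin(a * x) / b
```

Setting `b = a` and squaring the sine would recover the original Snake form.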

t-SNE on Learned Features#
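The feature-extraction code is not shown in this excerpt. As a sketch of how such an analysis might be run, one could embed the last hidden layer's activations with scikit-learn's `TSNE`; the `feats` array below is random placeholder data standing in for the real activations:

```python
import numpy as np
from sklearn.manifold import TSNE

# Placeholder for hidden-layer activations: 200 samples x 512 features
feats = np.random.randn(200, 512).astype(np.float32)

# Project the 512-dimensional features down to 2D for visualization
emb = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(feats)
```

The resulting 2D embedding can then be scattered and colored by the input x-value to inspect how the network organizes the periodic structure internally.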

Take-aways#

We performed…